import numpy as np
import plotly.graph_objects as go
import plotly.express as px

def objective_function(x, y):
    """Evaluate the quadratic bowl f(x, y) = x**2 + y**2 + 0.3*x*y.

    The arithmetic is element-wise, so scalars and NumPy arrays (e.g. the
    meshgrid used for the contour rings) are both accepted.
    """
    cross_term = 0.3 * x * y
    return x**2 + y**2 + cross_term

def gradient(x, y):
    """Return the analytic gradient of ``objective_function`` at (x, y).

    For scalar inputs the result is a length-2 array ``[df/dx, df/dy]``.
    """
    df_dx = 2*x + 0.3*y
    df_dy = 2*y + 0.3*x
    return np.array([df_dx, df_dy])

def stochastic_gradient_descent(initial_point, learning_rate, num_iterations, noise_scale=0.05, convergence_threshold=0.01, *, rng=None):
    """Run gradient descent with Gaussian-perturbed gradients from *initial_point*.

    Each step moves against ``gradient(x, y)`` plus isotropic Gaussian noise of
    standard deviation *noise_scale*, imitating a stochastic gradient estimate.

    Args:
        initial_point: Starting (x, y) coordinates (two elements).
        learning_rate: Step size applied to the noisy gradient.
        num_iterations: Maximum number of update steps.
        noise_scale: Standard deviation of the noise added to each gradient.
        convergence_threshold: Stop early once the Euclidean length of a
            step falls below this value.
        rng: Optional ``numpy.random.Generator`` used to draw the noise, so
            runs can be reproduced. When ``None`` (the default) the legacy
            global ``np.random`` state is used, exactly as before.

    Returns:
        A float array of shape ``(k + 1, 2)``: the starting point followed by
        the ``k <= num_iterations`` iterates that were computed.
    """
    if rng is None:
        def draw_noise():
            return np.random.randn(2)
    else:
        def draw_noise():
            return rng.standard_normal(2)

    # Copy as float so the returned path never aliases the caller's array.
    points = [np.array(initial_point, dtype=float)]
    for _ in range(num_iterations):
        grad = gradient(*points[-1])
        noise = noise_scale * draw_noise()
        update = -learning_rate * (grad + noise)
        points.append(points[-1] + update)

        # The step includes the noise term, so with a large noise_scale this
        # early-exit test may rarely (or spuriously) trigger.
        if np.linalg.norm(update) < convergence_threshold:
            break

    return np.array(points)

# Set parameters
initial_point = np.array([10, 10])
# Known minimiser of objective_function: x**2 + y**2 + 0.3*x*y is positive
# definite, so the origin is its only stationary point.
final_point = np.array([0, 0])
learning_rate = 0.1
num_iterations = 50
noise_scale = 5  # Std. deviation multiplier of the Gaussian gradient noise

# Evaluate the objective on a grid only to find the range of values to draw.
x_contour = np.linspace(-5, 5, 100)
y_contour = np.linspace(-5, 5, 100)
X, Y = np.meshgrid(x_contour, y_contour)
Z = objective_function(X, Y)

# Decorative level rings. NOTE: these are circles of radius
# spacing_factor * sqrt(level), not the true level sets of objective_function
# (which are slightly elliptical because of the 0.3*x*y cross term).
contour_levels = np.linspace(Z.min(), Z.max(), 7)
spacing_factor = 1.5

colors = px.colors.sequential.Viridis

# Pick one evenly spaced Viridis colour per ring.
num_colors = len(contour_levels)
viridis_colors = [colors[int(i * len(colors) / num_colors)] for i in range(num_colors)]

# The angle samples are the same for every ring, so compute them once.
theta = np.linspace(0, 2 * np.pi, 100)
scatter_plots = []
for ring_color, level in zip(viridis_colors, contour_levels):
    radius = spacing_factor * np.sqrt(level)
    scatter_plots.append(
        go.Scatter(
            x=radius * np.cos(theta),
            y=radius * np.sin(theta),
            mode='lines',
            line=dict(width=2, color=ring_color),
            showlegend=False,
        )
    )

# Generate SGD trajectory
trajectory = stochastic_gradient_descent(initial_point, learning_rate, num_iterations, noise_scale=noise_scale)

# Create layout. scaleanchor keeps one unit equal on both axes so the
# level rings actually render as circles.
layout = go.Layout(
    title='Stochastic Gradient Descent Simulation',
    xaxis=dict(title='', showgrid=False),
    yaxis=dict(title='', showgrid=False, scaleanchor='x', scaleratio=1),
)


def _sgd_path_trace(num_points):
    """Return the dashed red SGD path showing the first *num_points* iterates."""
    return go.Scatter(
        x=trajectory[:num_points, 0],
        y=trajectory[:num_points, 1],
        mode='lines+markers',
        line=dict(color='red', dash='dash'),
        name='SGD Simulation',
    )


# The path trace sits directly after the rings. Each frame updates ONLY that
# trace (via `traces=`). Without explicit indices Plotly applies frame data
# positionally, and the path would overwrite whatever trace sits at that index
# (previously the "Initial Point" marker).
path_trace_index = len(scatter_plots)

# Create frames for animation
frames = [
    go.Frame(
        data=[_sgd_path_trace(i + 1)],
        traces=[path_trace_index],
        name=f'SGD Step {i+1}',
    )
    for i in range(len(trajectory))
]

# Create figure: rings, then the (initially one-point) SGD path.
fig = go.Figure(data=[*scatter_plots, _sgd_path_trace(1)], layout=layout, frames=frames)

# Add the starting point and the known minimiser as static markers.
fig.add_trace(go.Scatter(x=[initial_point[0]], y=[initial_point[1]], mode='markers', marker=dict(color='red'), name='Initial Point'))
fig.add_trace(go.Scatter(x=[final_point[0]], y=[final_point[1]], mode='markers', marker=dict(color='green'), name='Final Point'))

# Set animation settings
fig.update_layout(
    plot_bgcolor='white',
    updatemenus=[
        dict(
            type='buttons',
            showactive=False,
            buttons=[
                dict(
                    label='Play',
                    method='animate',
                    args=[None, dict(frame=dict(duration=500, redraw=True), fromcurrent=True, mode='immediate')],
                ),
                # [None] (a list containing None) is Plotly's idiom for "stop the animation".
                dict(
                    label='Pause',
                    method='animate',
                    args=[[None], dict(frame=dict(duration=0, redraw=True), mode='immediate')],
                ),
            ],
        )
    ],
)

# Show the plot
fig.show()